# Headers that every kernel compile rule lists as a dependency (they are
# appended to DEPENDS in build_kernel_base), so editing any of them
# rebuilds all .air files.
set(BASE_HEADERS bf16.h bf16_math.h complex.h defines.h utils.h)

# build_kernel_base(<TARGET> <SRCFILE> <DEPS>)
#
# Adds a build rule that compiles one Metal source file into `<TARGET>.air`
# using `xcrun -sdk macosx metal`.
#
#   TARGET   Base name of the output; the rule produces `${TARGET}.air`.
#   SRCFILE  Metal source file handed to the compiler via `-c`.
#   DEPS     List of additional files whose modification should trigger a
#            rebuild. May be empty.
#
# Reads from the enclosing scope:
#   MLX_METAL_VERSION   passed to the compiler as a `-D` definition
#   MLX_METAL_DEBUG     when true, adds line-table / source-recording flags
#   BASE_HEADERS        always added to the rule's dependencies
#   PROJECT_SOURCE_DIR  used as the include root (`-I`)
function(build_kernel_base TARGET SRCFILE DEPS)
  set(METAL_FLAGS -Wall -Wextra -fno-fast-math -D${MLX_METAL_VERSION})
  if(MLX_METAL_DEBUG)
    # Extra debug info, only when requested.
    list(APPEND METAL_FLAGS -gline-tables-only -frecord-sources)
  endif()

  # The output path is relative on purpose: the same relative name is used
  # for OUTPUT and for `-o`, so both refer to the same file.
  add_custom_command(
    OUTPUT ${TARGET}.air
    COMMAND xcrun -sdk macosx metal
                  ${METAL_FLAGS}
                  -c ${SRCFILE}
                  -I${PROJECT_SOURCE_DIR}
                  -o ${TARGET}.air
    DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS}
    COMMENT "Building ${TARGET}.air"
    VERBATIM
  )
endfunction()

# build_kernel(<KERNEL> [<dep>...])
#
# Registers the compile rule for `${CMAKE_CURRENT_SOURCE_DIR}/<KERNEL>.metal`.
#
#   KERNEL  Kernel path without the `.metal` extension; it may contain
#           directories (e.g. `steel/gemm/kernels/steel_gemm_fused`).
#   ARGN    Extra files forwarded to build_kernel_base as its DEPS list.
#
# The output is named after the stem of KERNEL only (directories dropped),
# and `<stem>.air` is prepended to KERNEL_AIR in the caller's scope.
function(build_kernel KERNEL)
  cmake_path(GET KERNEL STEM KERNEL_STEM)
  set(KERNEL_SOURCE ${CMAKE_CURRENT_SOURCE_DIR}/${KERNEL}.metal)

  # ARGN is quoted so the whole dependency list arrives as the single DEPS
  # argument, even when it is empty.
  build_kernel_base(${KERNEL_STEM} ${KERNEL_SOURCE} "${ARGN}")

  # Put the new object at the front of the list, then export the result.
  list(PREPEND KERNEL_AIR ${KERNEL_STEM}.air)
  set(KERNEL_AIR "${KERNEL_AIR}" PARENT_SCOPE)
endfunction()

# Kernels registered unconditionally (they sit outside the MLX_METAL_JIT
# check further down). The first argument is the kernel name; any remaining
# arguments are extra header dependencies for that kernel.
build_kernel(arg_reduce)
build_kernel(
  conv
  steel/conv/params.h
)
build_kernel(
  gemv
  steel/utils.h
)
build_kernel(
  gemv_masked
  steel/utils.h
)
build_kernel(layer_norm)
build_kernel(random)
build_kernel(rms_norm)
build_kernel(rope)
build_kernel(scaled_dot_product_attention
             scaled_dot_product_attention_params.h
             steel/defines.h
             steel/gemm/transforms.h
             steel/utils.h)

# Shared header list for the "steel" kernels; passed as the extra
# dependency list to build_kernel for the quantized and steel_* kernels.
set(STEEL_HEADERS
    # common
    steel/defines.h
    steel/utils.h
    # convolution
    steel/conv/conv.h
    steel/conv/loader.h
    steel/conv/loaders/loader_channel_l.h
    steel/conv/loaders/loader_channel_n.h
    steel/conv/loaders/loader_general.h
    steel/conv/kernels/steel_conv.h
    steel/conv/kernels/steel_conv_general.h
    # matrix multiplication
    steel/gemm/gemm.h
    steel/gemm/mma.h
    steel/gemm/loader.h
    steel/gemm/transforms.h
    steel/gemm/kernels/steel_gemm_fused.h
    steel/gemm/kernels/steel_gemm_masked.h
    steel/gemm/kernels/steel_gemm_splitk.h)

# These kernels are only registered for ahead-of-time compilation when
# MLX_METAL_JIT is false. NOTE(review): presumably they are compiled at
# runtime when JIT is enabled -- confirm against the JIT sources.
if(NOT MLX_METAL_JIT)
  # Element-wise and utility kernels.
  build_kernel(arange arange.h)
  build_kernel(binary binary.h binary_ops.h)
  build_kernel(binary_two binary_two.h)
  build_kernel(copy copy.h)
  build_kernel(fft fft.h fft/radix.h fft/readwrite.h)
  build_kernel(reduce
               atomic.h
               reduction/ops.h
               reduction/reduce_init.h
               reduction/reduce_all.h
               reduction/reduce_col.h
               reduction/reduce_row.h)
  build_kernel(quantized quantized.h ${STEEL_HEADERS})
  build_kernel(scan scan.h)
  build_kernel(softmax softmax.h)
  build_kernel(sort sort.h)
  build_kernel(ternary ternary.h ternary_ops.h)
  build_kernel(unary unary.h unary_ops.h)

  # Steel kernels: each one depends on the full STEEL_HEADERS list.
  build_kernel(steel/conv/kernels/steel_conv ${STEEL_HEADERS})
  build_kernel(steel/conv/kernels/steel_conv_general ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_fused ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_masked ${STEEL_HEADERS})
  build_kernel(steel/gemm/kernels/steel_gemm_splitk ${STEEL_HEADERS})
endif()


# Run `metallib` over every .air file collected in KERNEL_AIR to produce
# the single mlx.metallib archive under MLX_METAL_PATH. The rule re-runs
# whenever any of those .air files changes.
add_custom_command(
  COMMENT "Building mlx.metallib"
  OUTPUT ${MLX_METAL_PATH}/mlx.metallib
  DEPENDS ${KERNEL_AIR}
  COMMAND
    xcrun -sdk macosx metallib
    ${KERNEL_AIR}
    -o ${MLX_METAL_PATH}/mlx.metallib
  VERBATIM)

# Named target that wraps the metallib file so other targets can depend on
# it. It is declared without ALL, so it is built only when something
# depends on it or it is requested explicitly.
add_custom_target(mlx-metallib DEPENDS ${MLX_METAL_PATH}/mlx.metallib)

# Make the `mlx` target (defined outside this file) depend on the
# metallib, so the metallib is brought up to date whenever `mlx` is built.
add_dependencies(mlx mlx-metallib)

# Install the metallib into the library directory, under its own
# "metallib" install component. GNUInstallDirs supplies
# CMAKE_INSTALL_LIBDIR.
include(GNUInstallDirs)

install(FILES ${MLX_METAL_PATH}/mlx.metallib
        DESTINATION ${CMAKE_INSTALL_LIBDIR}
        COMPONENT metallib)
